import torch
from torch import nn

class ProjectionLayer(nn.Module):
    """Map feature vectors to log-probabilities over a vocabulary.

    Applies an affine layer (``nn.Linear``) from ``dim_model`` features to
    ``vocab_size`` scores. It then normalizes those scores with a
    log-softmax over the last dimension.

    Args:
        dim_model: Size of the last dimension of the input tensor.
        vocab_size: Number of output scores per position (size of the
            output's last dimension).
    """

    def __init__(self, dim_model: int, vocab_size: int) -> None:
        super().__init__()
        # The attribute name ``proj`` is part of the state_dict key layout
        # ("proj.weight", "proj.bias"), so keep it stable for checkpoints.
        self.proj = nn.Linear(dim_model, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return log-softmax-normalized scores.

        Args:
            x: Tensor whose last dimension has size ``dim_model``; any
                leading dimensions are carried through unchanged by
                ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``vocab_size``. Its values exponentiate
            and sum to 1 along that last dimension.
        """
        scores = self.proj(x)
        return scores.log_softmax(dim=-1)